#!/usr/bin/env python3
import configparser
import contextlib
import json
import logging
import os
import subprocess
import sys
from collections import defaultdict
from collections.abc import Mapping
from datetime import datetime, timedelta, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Protocol
from urllib.parse import parse_qs, urlsplit


class GaugeMetric(Protocol):
    def __call__(self, value: float, labels: Mapping[str, str] | None = None, /) -> None:
        pass


class WalGPrometheusServer(BaseHTTPRequestHandler):
    """HTTP request handler exposing wal-g backup statistics as Prometheus gauges.

    A ``GET`` on ``/metrics`` runs ``env-wal-g backup-list --detail --json`` and
    renders the result in the Prometheus text exposition format; any other path
    gets an empty 404 response.

    If the request carries exactly one ``bucket`` query parameter, that bucket
    is passed to wal-g through ``WALG_S3_PREFIX``.  Otherwise the environment is
    passed through unchanged and the bucket name used for labels is read from
    ``/etc/zulip/zulip-secrets.conf``.
    """

    METRIC_PREFIX = "wal_g_backup_"

    # Metric name -> preamble lines (``# TYPE`` / ``# HELP``).  ``do_GET``
    # replaces this with a fresh per-instance dict on every request.
    metrics: dict[str, list[str]] = {}
    # Metric name -> {rendered label string -> full sample line}.  Also replaced
    # per request by ``do_GET``.
    metric_values: dict[str, dict[str, str]] = defaultdict(dict)

    server_version = "wal-g-prometheus-server/1.0"

    @staticmethod
    def _escape_label_value(value: object) -> str:
        """Return ``value`` escaped for use inside a quoted Prometheus label value.

        The text exposition format requires backslash, double quote and line
        feed to be written as ``\\\\``, ``\\"`` and ``\\n``.  Without this a label
        value (e.g. the ``bucket`` query parameter) could break out of its
        quotes and inject arbitrary metric lines.
        """
        return str(value).replace("\\", "\\\\").replace('"', '\\"').replace("\n", "\\n")

    def gauge(
        self, name: str, description: str | None = None, default_value: float | None = None
    ) -> GaugeMetric:
        """Register a gauge and return a callable that records samples for it.

        Args:
            name: Metric name without the ``METRIC_PREFIX``.
            description: Optional text emitted on the ``# HELP`` line.
            default_value: If given, recorded immediately as the unlabelled
                sample.

        Returns:
            A callable ``(value, labels=None)``; each call stores one sample
            line, replacing any earlier sample with the same label set.

        Raises:
            ValueError: If ``name`` is already registered in ``self.metrics``.
        """
        if name in self.metrics:
            raise ValueError(f"Redefinition of {name} metric")

        full_name = f"{self.METRIC_PREFIX}{name}"
        preamble = [f"# TYPE {full_name} gauge"]
        if description is not None:
            # HELP text must have backslashes and line feeds escaped.
            help_text = description.replace("\\", "\\\\").replace("\n", "\\n")
            preamble.append(f"# HELP {full_name} {help_text}")
        self.metrics[name] = preamble

        def inner(value: float, labels: Mapping[str, str] | None = None) -> None:
            label_str = ""
            if labels:
                pairs = ",".join(
                    f'{key}="{self._escape_label_value(val)}"' for key, val in labels.items()
                )
                label_str = "{" + pairs + "}"
            # Keyed by the rendered label string so a repeated label set
            # overwrites the previous sample instead of duplicating it.
            self.metric_values[name][label_str] = f"{full_name}{label_str} {value}"

        if default_value is not None:
            inner(default_value)
        return inner

    def print_metrics(self) -> None:
        """Write every registered metric that has at least one sample to ``self.wfile``.

        Metrics are emitted in registration order: preamble lines, then sample
        lines, then a blank separator line.  Metrics with no samples are
        omitted entirely.
        """
        lines: list[str] = []
        for metric_name, preamble in self.metrics.items():
            samples = self.metric_values[metric_name]
            if samples:
                lines += preamble
                lines += samples.values()
                lines.append("")
        self.wfile.write("\n".join(lines).encode())

    def do_GET(self) -> None:
        """Handle a GET request; serve metrics on ``/metrics`` and 404 elsewhere.

        The 200 status is sent before wal-g is invoked.  Failures while
        gathering data are logged and reported through the ``ok`` gauge, which
        stays at its default of 0 unless the whole collection succeeds; samples
        recorded before the failure are still emitted.
        """
        if urlsplit(self.path).path != "/metrics":
            self.send_response(404)
            self.end_headers()
            sys.stderr.flush()
            return

        self.send_response(200)
        self.send_header("Content-type", "text/plain; version=0.0.4")
        self.end_headers()

        # Start from a clean, per-request registry (shadows the class attributes).
        self.metrics = {}
        self.metric_values = defaultdict(dict)

        backup_ok = self.gauge("ok", "If wal-g backup-list was OK", 0)
        backup_count = self.gauge("count", "Number of backups")
        backup_earliest_age_seconds = self.gauge("earliest_age_seconds", "Age of the oldest backup")
        backup_latest_age_seconds = self.gauge(
            "latest_age_seconds", "Age of the most recent backup"
        )
        backup_latest_duration_seconds = self.gauge(
            "latest_duration_seconds", "Duration the most recent backup took, in seconds"
        )
        backup_latest_compressed_size_bytes = self.gauge(
            "latest_compressed_size_bytes", "Size of the most recent backup, in bytes"
        )
        backup_latest_uncompressed_size_bytes = self.gauge(
            "latest_uncompressed_size_bytes",
            "Uncompressed size of the most recent backup, in bytes",
        )
        backup_total_compressed_size_bytes = self.gauge(
            "total_compressed_size_bytes", "Total compressed size of all backups, in bytes"
        )

        now = datetime.now(tz=timezone.utc)
        try:
            query = parse_qs(urlsplit(self.path).query)
            modified_env = os.environ.copy()
            if query.get("bucket") and len(query["bucket"]) == 1:
                # Exactly one bucket requested: point wal-g at it.
                bucket = query["bucket"][0]
                modified_env["WALG_S3_PREFIX"] = f"s3://{bucket}"
            else:
                # No (or ambiguous) bucket parameter: leave the environment
                # alone and take the label value from the secrets file.
                config_file = configparser.RawConfigParser()
                config_file.read("/etc/zulip/zulip-secrets.conf")
                bucket = config_file["secrets"]["s3_backups_bucket"]

            backup_list_output = subprocess.check_output(
                ["env-wal-g", "backup-list", "--detail", "--json"],
                text=True,
                env=modified_env,
            )
            # Empty output is treated as "no backups".
            data = json.loads(backup_list_output) if backup_list_output else []
            backup_count(len(data), {"bucket": bucket})

            backup_total_compressed_size_bytes(
                sum(e["compressed_size"] for e in data), {"bucket": bucket}
            )

            if len(data) > 0:
                # Newest first: data[0] is the latest backup, data[-1] the earliest.
                data.sort(key=lambda e: e["start_time"], reverse=True)
                latest = data[0]
                earliest = data[-1]
                labels = {
                    "host": latest["hostname"],
                    "pg_version": str(latest["pg_version"]),
                    "bucket": bucket,
                }
                backup_latest_compressed_size_bytes(latest["compressed_size"], labels)
                backup_latest_uncompressed_size_bytes(latest["uncompressed_size"], labels)

                def t(key: str, e: dict[str, str] = latest) -> datetime:
                    # Parse with the entry's own "date_fmt" and tag the result as UTC.
                    return datetime.strptime(e[key], e["date_fmt"]).replace(tzinfo=timezone.utc)

                backup_earliest_age_seconds(
                    (now - t("start_time", earliest)) / timedelta(seconds=1),
                    {
                        "host": earliest["hostname"],
                        "pg_version": str(earliest["pg_version"]),
                        "bucket": bucket,
                    },
                )
                backup_latest_age_seconds((now - t("start_time")) / timedelta(seconds=1), labels)
                backup_latest_duration_seconds(
                    (t("finish_time") - t("start_time")) / timedelta(seconds=1), labels
                )
            # Only reached when everything above succeeded.
            backup_ok(1)
        except Exception as e:
            # Top-level boundary: log and fall through so the scrape still gets
            # a response with ok=0.
            logging.exception(e)
        finally:
            self.print_metrics()
            self.log_message(
                "Served in %.2f seconds",
                (datetime.now(tz=timezone.utc) - now) / timedelta(seconds=1),
            )
            sys.stderr.flush()


# Script entry: configure logging, bind the metrics endpoint on localhost,
# and serve until interrupted from the keyboard.
logging.basicConfig(format="%(asctime)s %(levelname)s: %(message)s", level=logging.INFO)
listen_address = ("127.0.0.1", 9188)
server = HTTPServer(listen_address, WalGPrometheusServer)
logging.info("Starting server...")
# Ctrl-C is the expected way to stop; swallow it so shutdown proceeds below.
with contextlib.suppress(KeyboardInterrupt):
    server.serve_forever()

# Release the listening socket, then record the shutdown.
server.server_close()
logging.info("Stopping server...")
